"""
Plots the figures and computes the statistics reported in the paper. Written for Python 3.

by: Jacob Malone (j.malone@cablelabs.com)
date: August 25, 2021
"""


import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Widen pandas console output so the diagnostics printed below are not truncated.
for _display_option in ('display.max_rows', 'display.max_columns'):
    pd.set_option(_display_option, 1000)
pd.set_option('display.width', None)

# Directory holding this script's parquet inputs.
# NOTE(review): these inputs were described as coming from "03_id_cord_cutters.py" although the files carry
# an 04_ prefix -- confirm which upstream script writes them.
data_dir = '/mnt/et01/cord_cutting_rep'

# Read the three inputs in a fixed order: cord-cutter usage, other-subscriber usage, video date ranges.
cord_cutters, other_subs, video_dates = (
    pd.read_parquet(f'{data_dir}/04_{stem}.parquet')
    for stem in ('cord_cutters', 'other_subs', 'video_dates')
)

#####

# Calculate average daily usage for subs who didn't cord-cut
other_subs_daily_avgs = other_subs.groupby('date')[['tot_gb', 'level_video', 'level_browsing']].mean().reset_index()
other_subs_daily_avgs.columns = ['date', 'other_subs_tot_gb_avg', 'other_subs_vid_avg', 'other_subs_brows_avg']
# NOTE(review): the comparison group's video average is built from level_video alone, while the cord-cutters'
# "Streaming" series in Figure 1 is level_streaming after level_video is folded into it below -- confirm the
# two measures are meant to be compared directly.

# Compare cord-cutter daily usage to other subscribers
cord_cutters = cord_cutters.merge(other_subs_daily_avgs, on='date', how='left')
cord_cutters['tot_gb_diff'] = cord_cutters.tot_gb - cord_cutters.other_subs_tot_gb_avg
cord_cutters['level_video_diff'] = cord_cutters.level_video - cord_cutters.other_subs_vid_avg

# Attach each customer's video / no-video date ranges; min_dt_wo_video is used as the cord-cut date throughout.
cord_cutters = cord_cutters.merge(video_dates, on='customer_key', how='left')
cord_cutters['days_to_cut'] = cord_cutters.date - cord_cutters.min_dt_wo_video
# ISO week numbers (Series.dt.week was deprecated in pandas 1.1 and removed in pandas 2.0).
cord_cutters['week_num'] = cord_cutters.date.dt.isocalendar().week.astype('Int64')
cord_cutters['week_of_cut'] = cord_cutters.min_dt_wo_video.dt.isocalendar().week.astype('Int64')
# Weeks relative to the cut, counted between the Mondays that start each ISO week. Within one ISO year this is
# exactly week_num - week_of_cut, but unlike that difference it stays correct when the window crosses a year
# boundary (where ISO week numbers restart at 1). Rounding to whole days absorbs DST shifts on tz-aware dates.
date_week_start = cord_cutters.date.dt.normalize() - pd.to_timedelta(cord_cutters.date.dt.weekday, unit='D')
cut_week_start = (cord_cutters.min_dt_wo_video.dt.normalize()
                  - pd.to_timedelta(cord_cutters.min_dt_wo_video.dt.weekday, unit='D'))
cord_cutters['weeks_to_cut'] = (date_week_start - cut_week_start).dt.round('D').dt.days // 7
cord_cutters['days_w_video'] = cord_cutters.max_dt_w_video - cord_cutters.min_dt_w_video
cord_cutters['days_wo_video'] = cord_cutters.max_dt_wo_video - cord_cutters.min_dt_wo_video

# Restrict sample to only include cord-cutters with 7 days on either side
cord_cutters = cord_cutters[((cord_cutters.days_w_video >= pd.Timedelta(days=7)) &
                             (cord_cutters.days_wo_video >= pd.Timedelta(days=7)))].copy()
print(cord_cutters.customer_key.nunique())

# Consolidate usage groups for plotting.
# Order matters: level_vidother absorbs level_streaming *before* level_streaming absorbs level_video.
print('Consolidating usage groups')
cord_cutters['level_vidother'] += cord_cutters.level_streaming + cord_cutters.level_viddish
cord_cutters['level_streaming'] += cord_cutters.level_video
cord_cutters['level_other'] += cord_cutters[['level_admin',
                                             'level_cdn',
                                             'level_sharing',
                                             'level_tunnel']].sum(axis=1)
cord_cutters = cord_cutters.drop(['level_admin',
                                  'level_cdn',
                                  'level_sharing',
                                  'level_tunnel',
                                  'level_viddish'], axis=1)

# Define variable lists of overall usage categories and video usage categories
level_vars = ['level_backup',
              'level_browsing',
              'level_gaming',
              'level_music',
              'level_other',
              'level_streaming']
video_vars = ['level_vidamazon',
              'level_vidflash',
              'level_vidhbo',
              'level_vidhulu',
              'level_vidnetflix',
              'level_vidother',
              'level_vidslingtv',
              'level_vidyoutube']
reduce_vars = \
    level_vars + video_vars + ['other_subs_tot_gb_avg', 'other_subs_vid_avg',
                               'other_subs_brows_avg', 'tot_gb', 'level_video']

# One row per customer per week relative to the cut, averaging daily usage within the week.
weekly_dat = cord_cutters.groupby(['customer_key', 'weeks_to_cut'])[reduce_vars].mean().reset_index()
weekly_dat_8weeks = weekly_dat[(weekly_dat.weeks_to_cut >= -8) & (weekly_dat.weeks_to_cut <= 8)].copy()  # Reduce to timeframe of figure

# Label each week as before / week_of / after the cut. This must exist here: the Figure 2 code further down
# groups weekly_dat_8weeks by period_desc.
weekly_dat_8weeks['period_desc'] = 'week_of'
weekly_dat_8weeks.loc[weekly_dat_8weeks.weeks_to_cut < 0, 'period_desc'] = 'before'
weekly_dat_8weeks.loc[weekly_dat_8weeks.weeks_to_cut > 0, 'period_desc'] = 'after'

#####

# FIGURE 1 PLOT
weekly_avgs = weekly_dat_8weeks.groupby(['weeks_to_cut'])[reduce_vars].mean().reset_index()


def _stack_usage_columns(frame, value_cols):
    """Reshape ``weeks_to_cut`` plus three usage columns of ``frame`` into long form.

    ``value_cols`` are relabelled, in order, as 'Total', 'Streaming' and 'Browsing'. The result has columns
    ``weeks_to_cut``, ``usage_type`` and ``gb``, one row per (week, usage type).
    """
    wide = frame.loc[:, ['weeks_to_cut'] + list(value_cols)].set_index('weeks_to_cut')
    wide.columns = ['Total', 'Streaming', 'Browsing']
    return wide.stack().reset_index().rename(columns={'level_1': 'usage_type', 0: 'gb'})


def _plot_fitted_line(ax, data, rows, color, linestyle):
    """Least-squares fit ``gb`` on ``weeks_to_cut`` over ``rows`` of ``data``, draw it, and return the line.

    The line spans only the x values of the selected rows.
    """
    subset = data.loc[rows]
    x = subset.weeks_to_cut.values
    slope, intercept = np.polyfit(x, subset.gb.values, 1)
    line, = ax.plot(x, x * slope + intercept, color=color, linestyle=linestyle, lw=3)
    return line


def plot_weekly_before_after_diff_in_diff_fig_just_reg_lines(pframe, filename):
    """Plot pre- and post-cut linear usage trends for cord-cutters (solid) vs. other subscribers (dashed).

    For each usage type (Total, Streaming, Browsing) a straight line is fitted separately to the weeks before
    the cut (``weeks_to_cut < 0``) and to the weeks after it (``weeks_to_cut > 0``); the cut week is excluded.

    Parameters
    ----------
    pframe : DataFrame
        One row per ``weeks_to_cut`` with cord-cutter columns ``tot_gb``, ``level_streaming``,
        ``level_browsing`` and comparison columns ``other_subs_tot_gb_avg``, ``other_subs_vid_avg``,
        ``other_subs_brows_avg``.
    filename : str
        Output file stem; writes ``./figs/<filename>.eps`` and ``./figs/<filename>.png``.
    """
    cc_data = _stack_usage_columns(pframe, ['tot_gb', 'level_streaming', 'level_browsing'])
    others_data = _stack_usage_columns(pframe, ['other_subs_tot_gb_avg', 'other_subs_vid_avg',
                                                'other_subs_brows_avg'])

    f, ax = plt.subplots()

    ax.set_xlim([-8.5, 8.5])
    ax.set_xticks([-8, -6, -4, -2, 0, 2, 4, 6, 8])

    ax.set_ylim([0, 8])
    ax.set_yticks([0, 2, 4, 6, 8])

    ax.set_ylabel('Average Daily Usage (Gigabytes)')
    ax.set_xlabel('Weeks Relative to Cord-Cut')

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    ax.annotate('Cord-\nCutters (solid)', (-8, 5.3), ha='left', va='center', size=10)
    ax.annotate('Other Subs (dashed)', (-8, 3.65), ha='left', va='center', size=10)

    cats = ['Total', 'Streaming', 'Browsing']
    colors = sns.color_palette('Greys_r')[::2]
    groups = ((cc_data, 'solid'), (others_data, 'dashed'))

    # Pre-cut trends. The solid (cord-cutter) line of each usage type is kept as its legend handle.
    legend_handles = []
    for cat, color in zip(cats, colors):
        for data, linestyle in groups:
            pre_rows = (data.usage_type == cat) & (data.weeks_to_cut < 0)
            line = _plot_fitted_line(ax, data, pre_rows, color, linestyle)
            if data is cc_data:
                legend_handles.append(line)

    ax.legend(legend_handles, cats, loc='upper left')

    # Post-cut trends, drawn after the pre-cut ones.
    for cat, color in zip(cats, colors):
        for data, linestyle in groups:
            post_rows = (data.usage_type == cat) & (data.weeks_to_cut > 0)
            _plot_fitted_line(ax, data, post_rows, color, linestyle)

    ax.yaxis.grid(linewidth=0.5, linestyle=':')
    ax.xaxis.grid(linewidth=0.5, linestyle=':')
    ax.set_axisbelow(True)

    f.savefig(f'./figs/{filename}.eps', bbox_inches='tight', dpi=300)
    f.savefig(f'./figs/{filename}.png', bbox_inches='tight', dpi=300)

    plt.close(f)


plot_weekly_before_after_diff_in_diff_fig_just_reg_lines(pframe=weekly_avgs,
                                                        filename='fig4_before_after_weekly_comp_just_reg_lines')

#####

# Per-service video usage: average within each customer and period, then across customers. The result has
# one row per level_vid* variable with 'before'/'after'/'week_of' columns.
video_avgs_8weeks = (weekly_dat_8weeks
                     .groupby(['customer_key', 'period_desc'])[video_vars].mean()
                     .reset_index()
                     .groupby('period_desc')[video_vars].mean()
                     .transpose())
video_avgs_8weeks['gb_diff'] = video_avgs_8weeks['after'] - video_avgs_8weeks['before']
video_avgs_8weeks['pct_change'] = video_avgs_8weeks['gb_diff'] / video_avgs_8weeks['before']
video_avgs_8weeks['pct_total_change'] = video_avgs_8weeks['gb_diff'] / video_avgs_8weeks['gb_diff'].sum()
video_avgs_8weeks = video_avgs_8weeks.reset_index()

# Display names: drop the 'level_vid' prefix, title-case, then fix brand spellings.
video_avgs_8weeks['usage_cat'] = (video_avgs_8weeks['index']
                                  .str.slice(start=9)
                                  .str.title()
                                  .replace({'Slingtv': 'Sling TV', 'Hbo': 'HBO GO', 'Youtube': 'YouTube'}))


# FIGURE 2, panel (a)
def plot_ottv_before_after_bar(pframe, filename):
    """Grouped bar chart of average daily GB per video service before vs. after the MSO TV drop.

    Each 'after' bar is labelled with its rounded percentage change relative to 'before'.

    Parameters
    ----------
    pframe : DataFrame
        One row per service with columns ``usage_cat`` (display name), ``before`` and ``after``.
    filename : str
        Output file stem; writes ``./figs/<filename>.eps`` and ``./figs/<filename>.png``.
    """
    pframe = pframe.loc[:, ['usage_cat', 'before', 'after']].copy()
    pframe = pframe.set_index('usage_cat').stack().reset_index().rename(columns={0: 'gb'})
    # Descending period_desc puts 'before' ahead of 'after' within each category, so pct_change lands on the
    # 'after' row; the lookup then holds one after-vs-before change per category, in category order.
    pframe = pframe.sort_values(['usage_cat', 'period_desc'], ascending=[True, False])
    pframe['pct_change'] = pframe.groupby('usage_cat').gb.pct_change() * 100
    pct_change_lookup = pframe.dropna(subset=['pct_change']).reset_index(drop=True).to_dict(orient='index')

    f, ax = plt.subplots()

    sns.barplot(data=pframe, x='usage_cat', y='gb', hue='period_desc', hue_order=['before', 'after'], ax=ax,
                palette='Greys')

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    ax.yaxis.grid(linewidth=0.5, linestyle=':')
    ax.set_axisbelow(True)

    ax.set_ylabel('Average Daily Usage (Gigabytes)')
    ax.set_xlabel('')

    # Assumes seaborn adds all 'before' bars first and then all 'after' bars, each in category order.
    n_cats = pframe.usage_cat.nunique()
    patches = ax.patches[::n_cats]
    ax.legend(patches,
              ['Before MSO TV Drop',
               'After MSO TV Drop'],
              loc='upper left')

    for idx, bar in enumerate(ax.patches[n_cats:]):
        pct = pct_change_lookup[idx]['pct_change']
        if not np.isfinite(pct):
            continue  # e.g. a zero 'before' value: no meaningful percentage to print
        pct = int(round(pct))  # round rather than truncate toward zero
        label = f'+{pct}%' if pct > 0 else f'{pct}%'

        # NOTE(review): the x anchor is the bar's left edge; with two hue levels this is presumably the
        # category's centre line -- confirm against the rendered figure.
        ax.annotate(label,
                    (bar.get_xy()[0], bar.get_xy()[1] + bar.get_height() + 0.01),
                    ha='center',
                    va='bottom')

    f.savefig(f'./figs/{filename}.eps', bbox_inches='tight', dpi=300)
    f.savefig(f'./figs/{filename}.png', bbox_inches='tight', dpi=300)

    plt.close(f)


plot_ottv_before_after_bar(pframe=video_avgs_8weeks, filename='03_ottv_before_after_bar')

#####

# Figure 2, panel (b)
# Per-service rate used to turn daily GB into daily minutes via gb / rate * 60 (so presumably GB per hour).
ottv_bitrates = {'level_vidhulu': 0.97,
                 'level_vidnetflix': 1.36,
                 'level_vidslingtv': 1.41,
                 'level_vidyoutube': 0.88}
ottv_vars = list(ottv_bitrates)

# Per-customer, per-period average usage, excluding the cut week itself.
fframe = weekly_dat_8weeks.loc[weekly_dat_8weeks.weeks_to_cut != 0].copy()
fframe = fframe.groupby(['customer_key', 'period_desc'])[ottv_vars].mean().reset_index()

viewing_estimates = []
for ottv in ottv_vars:
    watched = fframe[ottv] > 0
    owner = fframe.customer_key

    # Flag customers who used this app at some point before AND at some point after the cut.
    used_before = (watched & (fframe.period_desc == 'before')).astype(int).groupby(owner).transform('max') == 1
    used_after = (watched & (fframe.period_desc == 'after')).astype(int).groupby(owner).transform('max') == 1
    stayers = fframe.loc[used_before & used_after]

    # Average daily usage per period among those customers, converted to minutes.
    minutes = stayers.groupby('period_desc')[ottv].mean() / ottv_bitrates[ottv] * 60

    viewing_estimates.append((ottv,
                              minutes[minutes.index == 'before'].values[0],
                              minutes[minutes.index == 'after'].values[0]))

vframe = (pd.DataFrame(viewing_estimates, columns=['ottv_app', 'before', 'after'])
          .set_index('ottv_app')
          .stack()
          .reset_index()
          .rename(columns={'level_1': 'period_desc', 0: 'mins'}))


def plot_ottv_viewing_bar(pframe, filename):
    """Grouped bar chart of estimated daily viewing minutes per OTT app before vs. after the MSO TV drop.

    Each 'after' bar is labelled with its rounded percentage change relative to 'before'.

    Parameters
    ----------
    pframe : DataFrame
        Long-form frame with columns ``ottv_app`` (raw ``level_vid*`` names), ``period_desc``
        ('before'/'after') and ``mins``.
    filename : str
        Output file stem; writes ``./figs/<filename>.eps`` and ``./figs/<filename>.png``.
    """
    # Descending period_desc puts 'before' ahead of 'after' within each app, so pct_change lands on the 'after'
    # row; the lookup then holds one after-vs-before change per app, in app order.
    pframe = pframe.sort_values(['ottv_app', 'period_desc'], ascending=[True, False])
    pframe['pct_change'] = pframe.groupby('ottv_app').mins.pct_change() * 100
    pct_change_lookup = pframe.dropna(subset=['pct_change']).reset_index(drop=True).to_dict(orient='index')

    # Display names: drop the 'level_vid' prefix, title-case, then fix brand spellings.
    pframe['ottv_app'] = pframe['ottv_app'].str[9:].str.title()
    pframe.loc[pframe.ottv_app == 'Slingtv', 'ottv_app'] = 'Sling TV'
    pframe.loc[pframe.ottv_app == 'Youtube', 'ottv_app'] = 'YouTube'

    f, ax = plt.subplots()

    sns.barplot(data=pframe, x='ottv_app', y='mins', hue='period_desc', hue_order=['before', 'after'], ax=ax,
                palette='Greys')

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    ax.yaxis.grid(linewidth=0.5, linestyle=':')
    ax.set_axisbelow(True)

    ax.set_ylabel('Estimated Daily Viewing (Minutes)')
    ax.set_xlabel('')

    ax.set_ylim([0, 125])

    # Assumes seaborn adds all 'before' bars first and then all 'after' bars, each in app order.
    n_apps = pframe.ottv_app.nunique()
    patches = ax.patches[::n_apps]
    ax.legend(patches,
              ['Before MSO TV Drop',
               'After MSO TV Drop'],
              loc='upper left')

    for idx, bar in enumerate(ax.patches[n_apps:]):
        pct = pct_change_lookup[idx]['pct_change']
        if not np.isfinite(pct):
            continue  # e.g. a zero 'before' value: no meaningful percentage to print
        pct = int(round(pct))  # round rather than truncate toward zero
        label = f'+{pct}%' if pct > 0 else f'{pct}%'

        ax.annotate(label,
                    (bar.get_xy()[0], bar.get_xy()[1] + bar.get_height() + 0.05),
                    ha='center',
                    va='bottom')

    f.savefig(f'./figs/{filename}.eps', bbox_inches='tight', dpi=300)
    f.savefig(f'./figs/{filename}.png', bbox_inches='tight', dpi=300)

    plt.close(f)


plot_ottv_viewing_bar(pframe=vframe, filename='04_ottv_min_estimates')

#####

# Want to compare daily usage 8 weeks before and after.
# Negative weeks -> 'before', positive -> 'after', the cut week (and anything unmatched) -> 'week_of'.
weekly_dat_8weeks['period_desc'] = np.select(
    [weekly_dat_8weeks.weeks_to_cut < 0, weekly_dat_8weeks.weeks_to_cut > 0],
    ['before', 'after'],
    default='week_of',
)
print(weekly_dat_8weeks.period_desc.value_counts(dropna=False).sort_index())

# Per-customer average daily total usage in each period, with absolute and relative change.
tot_gb_8week_avg = (weekly_dat_8weeks
                    .groupby(['customer_key', 'period_desc'])['tot_gb'].mean()
                    .unstack('period_desc'))
tot_gb_8week_avg['diff'] = tot_gb_8week_avg['after'] - tot_gb_8week_avg['before']
tot_gb_8week_avg['growth'] = tot_gb_8week_avg['diff'] / tot_gb_8week_avg['before']


#####

# FIGURE A3
# Per-category usage: average within each customer and period, then across customers.
type_avgs_8weeks = (weekly_dat_8weeks
                    .groupby(['customer_key', 'period_desc'])[level_vars].mean()
                    .reset_index()
                    .groupby('period_desc')[level_vars].mean()
                    .transpose())
type_avgs_8weeks['gb_diff'] = type_avgs_8weeks['after'] - type_avgs_8weeks['before']
type_avgs_8weeks['pct_change'] = type_avgs_8weeks['gb_diff'] / type_avgs_8weeks['before']
type_avgs_8weeks['pct_total_change'] = type_avgs_8weeks['gb_diff'] / type_avgs_8weeks['gb_diff'].sum()
type_avgs_8weeks = type_avgs_8weeks.reset_index()
# Display names: drop the 'level_' prefix and title-case.
type_avgs_8weeks['usage_cat'] = type_avgs_8weeks['index'].str.slice(start=6).str.title()


def plot_before_after_bar(pframe, filename):
    """Grouped bar chart of average daily GB per usage category before vs. after the MSO TV drop.

    Each 'after' bar is labelled with its rounded percentage change relative to 'before'.

    Parameters
    ----------
    pframe : DataFrame
        One row per usage category with columns ``usage_cat`` (display name), ``before`` and ``after``.
    filename : str
        Output file stem; writes ``./figs/<filename>.eps`` and ``./figs/<filename>.png``.
    """
    pframe = pframe.loc[:, ['usage_cat', 'before', 'after']].copy()
    pframe = pframe.set_index('usage_cat').stack().reset_index().rename(columns={0: 'gb'})
    # Descending period_desc puts 'before' ahead of 'after' within each category, so pct_change lands on the
    # 'after' row; the lookup then holds one after-vs-before change per category, in category order.
    pframe = pframe.sort_values(['usage_cat', 'period_desc'], ascending=[True, False])
    pframe['pct_change'] = pframe.groupby('usage_cat').gb.pct_change() * 100
    pct_change_lookup = pframe.dropna(subset=['pct_change']).reset_index(drop=True).to_dict(orient='index')

    f, ax = plt.subplots()

    sns.barplot(data=pframe, x='usage_cat', y='gb', hue='period_desc', hue_order=['before', 'after'], ax=ax,
                palette='Greys')

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    ax.yaxis.grid(linewidth=0.5, linestyle=':')
    ax.set_axisbelow(True)

    ax.set_ylabel('Average Daily Usage (Gigabytes)')
    ax.set_xlabel('')

    # Assumes seaborn adds all 'before' bars first and then all 'after' bars, each in category order.
    n_cats = pframe.usage_cat.nunique()
    patches = ax.patches[::n_cats]
    ax.legend(patches,
              ['Before MSO TV Drop',
               'After MSO TV Drop'],
              loc='upper left')

    for idx, bar in enumerate(ax.patches[n_cats:]):
        pct = pct_change_lookup[idx]['pct_change']
        if not np.isfinite(pct):
            continue  # e.g. a zero 'before' value: no meaningful percentage to print
        pct = int(round(pct))  # round rather than truncate toward zero
        label = f'+{pct}%' if pct > 0 else f'{pct}%'

        ax.annotate(label,
                    (bar.get_xy()[0], bar.get_xy()[1] + bar.get_height() + 0.025),
                    ha='center',
                    va='bottom')

    f.savefig(f'./figs/{filename}.eps', bbox_inches='tight', dpi=300)
    f.savefig(f'./figs/{filename}.png', bbox_inches='tight', dpi=300)

    plt.close(f)


plot_before_after_bar(pframe=type_avgs_8weeks, filename='02_before_after_bar')

#####

# FIGURE A5, panel (a)
# Split daily rows at each customer's cut date (min_dt_wo_video); the cut date itself counts as "after".
cut_date = cord_cutters.min_dt_wo_video
before_cut = cord_cutters.loc[cord_cutters.date < cut_date].copy()
after_cut = cord_cutters.loc[cord_cutters.date >= cut_date].copy()

# Per-customer average monthly_amt in each period; the inner merge keeps customers present in both.
before_avg_rev, after_avg_rev = (
    period_rows.groupby('customer_key')['monthly_amt'].mean().reset_index()
    for period_rows in (before_cut, after_cut)
)
avg_rev = before_avg_rev.merge(after_avg_rev, on='customer_key', suffixes=('_before', '_after'))


def plot_before_after_rev_kde(pframe1, pframe2, filename):
    """Overlay kernel density estimates of monthly revenue before and after the MSO TV drop.

    Parameters
    ----------
    pframe1 : array-like
        Values for the "before" density (first Set2 colour).
    pframe2 : array-like
        Values for the "after" density (second Set2 colour).
    filename : str
        Output file stem; writes ``./figs/<filename>.eps`` and ``./figs/<filename>.png``.
    """
    set2 = sns.color_palette('Set2')
    arrow_style = {'facecolor': 'black', 'shrink': 0.025, 'width': 0.05, 'headwidth': 5, 'headlength': 5}

    fig, axis = plt.subplots(figsize=(8, 6))

    # Both densities are clipped to the plotted $0-$400 range.
    for values, colour in ((pframe1, set2[0]), (pframe2, set2[1])):
        sns.kdeplot(values, shade=True, bw='silverman', ax=axis, color=colour, clip=(0, 400))

    axis.set_xlabel('Total Monthly Revenue ($)')
    axis.set_ylabel('Density')
    axis.set_xlim([0, 400])

    for side in ('right', 'top'):
        axis.spines[side].set_visible(False)

    for grid_axis in (axis.yaxis, axis.xaxis):
        grid_axis.grid(linewidth=0.5, linestyle=':')
    axis.set_axisbelow(True)

    # Hand-placed labels pointing at each density curve; each call gets its own copy of the arrow style.
    axis.annotate('After MSO TV Drop', xy=(85, 0.015), xytext=(100, 0.017), arrowprops=dict(arrow_style))
    axis.annotate('Before MSO TV Drop', xy=(235, 0.005), xytext=(250, 0.007), arrowprops=dict(arrow_style))

    for extension in ('eps', 'png'):
        fig.savefig(f'./figs/{filename}.{extension}', bbox_inches='tight', dpi=300)

    plt.close(fig)


plot_before_after_rev_kde(pframe1=avg_rev.monthly_amt_before.values,
                          pframe2=avg_rev.monthly_amt_after.values,
                          filename='03_before_after_rev_kde')

#####

# FIGURE A5, panel (b)
ottv_vars = ['level_vidnetflix',
             'level_vidhbo',
             'level_vidhulu',
             'level_vidslingtv']
other_video_vars = ['level_vidamazon',
                    'level_vidyoutube',
                    'level_vidflash',
                    'level_vidother']
all_video_vars = ottv_vars + other_video_vars

# Per-customer average usage of each video service before and after the cut (customers present in both).
before_ottv_avgs = before_cut.groupby('customer_key')[all_video_vars].mean().reset_index()
after_ottv_avgs = after_cut.groupby('customer_key')[all_video_vars].mean().reset_index()
ottv_avgs = before_ottv_avgs.merge(after_ottv_avgs, on='customer_key', suffixes=('_before', '_after'))

# Monthly price ($) attributed to each subscription service, listed in the same order as ottv_vars.
service_prices = {'netflix_price': 9.99,
                  'hbo_price': 15.0,
                  'hulu_price': 7.99,
                  'slingtv_price': 30.0}
for price_col, price in service_prices.items():
    ottv_avgs[price_col] = price
price_vars = list(service_prices)

before_dummies = [f'before_dum_{p}' for p in ottv_vars]
after_dummies = [f'after_dum_{p}' for p in ottv_vars]

# A customer "subscribes" to a service in a period if they used any of it in that period.
for ov in ottv_vars:
    for period in ('before', 'after'):
        ottv_avgs[f'{period}_dum_{ov}'] = ottv_avgs[f'{ov}_{period}'] > 0

# Implied monthly spend: sum of prices of the services used (dummy columns align positionally with prices).
price_matrix = ottv_avgs[price_vars].values
ottv_avgs['ottv_amt_before'] = (ottv_avgs[before_dummies] * price_matrix).sum(axis=1)
ottv_avgs['ottv_amt_after'] = (ottv_avgs[after_dummies] * price_matrix).sum(axis=1)


def plot_before_after_ottv_rev_kde(pframe1, pframe2, filename):
    """Overlay kernel density estimates of implied OTT expenditure before and after the MSO TV drop.

    Parameters
    ----------
    pframe1 : array-like
        Values for the "before" density (first Set2 colour).
    pframe2 : array-like
        Values for the "after" density (second Set2 colour).
    filename : str
        Output file stem; writes ``./figs/<filename>.eps`` and ``./figs/<filename>.png``.
    """
    set2 = sns.color_palette('Set2')
    arrow_style = {'facecolor': 'black', 'shrink': 0.025, 'width': 0.05, 'headwidth': 5, 'headlength': 5}

    fig, axis = plt.subplots(figsize=(8, 6))

    # Both densities are clipped to the plotted $0-$75 range.
    for values, colour in ((pframe1, set2[0]), (pframe2, set2[1])):
        sns.kdeplot(values, shade=True, bw='silverman', ax=axis, color=colour, clip=(0, 75))

    axis.set_xlabel('Total Monthly Expenditure ($)')
    axis.set_ylabel('Density')
    axis.set_xlim([0, 75])

    for side in ('right', 'top'):
        axis.spines[side].set_visible(False)

    for grid_axis in (axis.yaxis, axis.xaxis):
        grid_axis.grid(linewidth=0.5, linestyle=':')
    axis.set_axisbelow(True)

    # Hand-placed labels pointing at each density curve; each call gets its own copy of the arrow style.
    axis.annotate('After MSO TV Drop', xy=(50, 0.01), xytext=(52, 0.012), arrowprops=dict(arrow_style))
    axis.annotate('Before MSO TV Drop', xy=(25, 0.0185), xytext=(27, 0.0205), arrowprops=dict(arrow_style))

    for extension in ('eps', 'png'):
        fig.savefig(f'./figs/{filename}.{extension}', bbox_inches='tight', dpi=300)

    plt.close(fig)


plot_before_after_ottv_rev_kde(ottv_avgs.ottv_amt_before.values,
                               ottv_avgs.ottv_amt_after.values,
                               '10a_ottv_amt_before_after_kde')

#####

# FIGURE A6, panel (a)
# Collapse the raw tier into three buckets: 0 (tier < 3), 1 (tier == 3), 2 (tier > 3). Rows with a missing
# tier keep tier_idx = NaN. Both frames are modified in place.
for cut_frame in (before_cut, after_cut):
    cut_frame['tier_idx'] = np.nan
    cut_frame.loc[cut_frame.tier < 3, 'tier_idx'] = 0
    cut_frame.loc[cut_frame.tier == 3, 'tier_idx'] = 1
    cut_frame.loc[cut_frame.tier > 3, 'tier_idx'] = 2

# Share of customers whose highest bucket (tiers >= 6 excluded) falls in each bucket, per period.
# The Series are named explicitly: value_counts() names its result after the source column in pandas < 2.0
# but 'proportion' in pandas >= 2.0, and the plot keys on 'tier_idx_before' / 'tier_idx_after'.
before_tiers = (before_cut.query('tier < 6').groupby('customer_key').tier_idx.max()
                .value_counts(normalize=True).rename('tier_idx_before'))
after_tiers = (after_cut.query('tier < 6').groupby('customer_key').tier_idx.max()
               .value_counts(normalize=True).rename('tier_idx_after'))

tiers = pd.merge(before_tiers, after_tiers, left_index=True, right_index=True)
# Name both axes so the long-form columns do not depend on pandas' default 'level_N' naming.
tiers = tiers.rename_axis(index='tier', columns='period_desc')
tiers_stacked = tiers.stack().rename('prop').reset_index()
tiers_stacked['tier'] = tiers_stacked.tier.astype(int)


def plot_before_after_plan_choices(pframe, filename):
    """Grouped bar chart of cord-cutters' internet tier buckets before vs. after the MSO TV drop.

    Parameters
    ----------
    pframe : DataFrame
        Long-form frame with columns ``tier`` (bucket 0, 1 or 2), ``period_desc``
        ('tier_idx_before' or 'tier_idx_after') and ``prop`` (share of cord-cutters).
    filename : str
        Output file stem; writes ``./figs/<filename>.eps`` and ``./figs/<filename>.png``.
    """
    # Deterministic row order: tier ascending, 'tier_idx_before' ahead of 'tier_idx_after'.
    pframe = pframe.sort_values(['tier', 'period_desc'], ascending=[True, False])

    f, ax = plt.subplots()

    sns.barplot(data=pframe, x='tier', y='prop', hue='period_desc', hue_order=['tier_idx_before', 'tier_idx_after'],
                ax=ax, palette='Set2', orient='v')

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    ax.yaxis.grid(linewidth=0.5, linestyle=':')
    ax.set_axisbelow(True)

    ax.set_ylabel('Proportion of Cord-Cutters')
    ax.set_xlabel('Internet Service Tier')

    ax.set_ylim([0, .65])

    # Labels are positional and assume all three buckets (0, 1, 2) are present.
    ax.set_xticklabels(['Below Median', 'Median', 'Above Median'])

    # Assumes seaborn adds all 'before' bars first and then all 'after' bars.
    patches = ax.patches[::pframe.tier.nunique()]
    ax.legend(patches,
              ['Before MSO TV Drop',
               'After MSO TV Drop'],
              loc='upper right')

    f.savefig(f'./figs/{filename}.eps', bbox_inches='tight', dpi=300)
    f.savefig(f'./figs/{filename}.png', bbox_inches='tight', dpi=300)

    plt.close(f)


plot_before_after_plan_choices(tiers_stacked, '04_before_after_tier_dist')

#####

# FIGURE A6, panel (b)
# Highest tier bucket per customer in each period (tiers >= 6 excluded); the inner merge keeps customers
# observed both before and after the cut.
before_tiers_max = before_cut.query('tier < 6').groupby('customer_key').tier_idx.max().reset_index()
after_tiers_max = after_cut.query('tier < 6').groupby('customer_key').tier_idx.max().reset_index()
tiers_max = pd.merge(before_tiers_max, after_tiers_max, on='customer_key', suffixes=('_before', '_after'))

# Row-normalised transition matrix: share of customers in each "after" bucket given their "before" bucket.
# crosstab counts co-occurrences by default, so no values/aggfunc are needed (a literal column name passed
# as `values` is broadcast as a constant rather than selecting the column).
plan_trans_mat = pd.crosstab(tiers_max.tier_idx_before,
                             tiers_max.tier_idx_after,
                             normalize='index')


def plot_trans_mat_heatmap(mat, filename):
    """Draw an annotated heatmap of the tier transition matrix and save it as EPS and PNG.

    Parameters
    ----------
    mat : DataFrame
        Transition matrix with rows = tier bucket before the drop and columns = bucket after; the tick labels
        below assume exactly three buckets.
    filename : str
        Output file stem; writes ``./figs/<filename>.eps`` and ``./figs/<filename>.png``.
    """
    tier_labels = ['Below Median', 'Median', 'Above Median']

    fig, axis = plt.subplots()

    sns.heatmap(mat,
                ax=axis,
                cmap='YlGnBu',
                vmin=0,
                vmax=1,
                annot=True,
                fmt='.2f',
                cbar=True,
                cbar_kws={'label': 'Transition Probability'})

    axis.set_yticklabels(tier_labels, rotation='horizontal')
    axis.set_xticklabels(tier_labels)

    axis.set_xlabel('Tier After Pay TV Drop')
    axis.set_ylabel('Tier Before Pay TV Drop')

    # NOTE(review): explicit (0, 3) limits presumably place the first row at the bottom, unlike seaborn's
    # default top-down heatmap orientation -- confirm against the rendered figure.
    axis.set_ylim([0, 3])
    axis.set_xlim([0, 3])

    for extension in ('eps', 'png'):
        fig.savefig(f'./figs/{filename}.{extension}', bbox_inches='tight', dpi=300)

    plt.close(fig)


plot_trans_mat_heatmap(mat=plan_trans_mat, filename='05_plan_trans_mat_row')
